# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

from flask import Flask
from sqlalchemy import text, TypeDecorator
from sqlalchemy.engine import Connection, Dialect, Row
from sqlalchemy_utils import EncryptedType

logger = logging.getLogger(__name__)


class AbstractEncryptedFieldAdapter(ABC):  # pylint: disable=too-few-public-methods
    """
    Interface for pluggable encrypted column type backends.

    Concrete adapters receive the application config plus pass-through
    arguments and return a SQLAlchemy ``TypeDecorator`` to use as a column type.
    """

    @abstractmethod
    def create(
        self,
        app_config: Optional[Dict[str, Any]],
        *args: List[Any],
        **kwargs: Optional[Dict[str, Any]],
    ) -> TypeDecorator:
        """
        Build an encrypted column type.

        :param app_config: Application config the adapter may read settings from
        :param args: Positional arguments passed through to the column type
        :param kwargs: Keyword arguments passed through to the column type
        :return: The column type instance
        """


class SQLAlchemyUtilsAdapter(  # pylint: disable=too-few-public-methods
    AbstractEncryptedFieldAdapter
):
    """Adapter that builds ``sqlalchemy_utils.EncryptedType`` columns."""

    def create(
        self,
        app_config: Optional[Dict[str, Any]],
        *args: List[Any],
        **kwargs: Optional[Dict[str, Any]],
    ) -> TypeDecorator:
        """
        Build an ``EncryptedType`` keyed with the app's ``SECRET_KEY``.

        :param app_config: Application config; its ``SECRET_KEY`` entry is used
            as the encryption key
        :param args: Positional args placed before the key in the
            ``EncryptedType`` call
        :param kwargs: Keyword args forwarded to ``EncryptedType``
        :return: The ``EncryptedType`` instance
        :raises Exception: If ``app_config`` is ``None`` or empty
        """
        # Guard first: an absent/empty config cannot provide a key
        if not app_config:
            raise Exception("Missing app_config kwarg")

        secret_key = app_config["SECRET_KEY"]
        return EncryptedType(*args, secret_key, **kwargs)


class EncryptedFieldFactory:
    """
    Creates encrypted column types via the adapter class configured under
    ``SQLALCHEMY_ENCRYPTED_FIELD_TYPE_ADAPTER``.

    ``init_app`` must be called before ``create``.
    """

    def __init__(self) -> None:
        # Both are populated by init_app
        self._config: Optional[Dict[str, Any]] = None
        self._concrete_type_adapter: Optional[AbstractEncryptedFieldAdapter] = None

    def init_app(self, app: Flask) -> None:
        """
        Bind the factory to ``app``: keep a reference to its config and
        instantiate the configured adapter class.

        :param app: The Flask application
        """
        self._config = app.config
        adapter_cls = self._config["SQLALCHEMY_ENCRYPTED_FIELD_TYPE_ADAPTER"]
        self._concrete_type_adapter = adapter_cls()

    def create(
        self, *args: List[Any], **kwargs: Optional[Dict[str, Any]]
    ) -> TypeDecorator:
        """
        Build an encrypted column type through the configured adapter.

        :param args: Positional args forwarded to the adapter
        :param kwargs: Keyword args forwarded to the adapter
        :return: The column type produced by the adapter
        :raises Exception: If ``init_app`` has not been called yet
        """
        adapter = self._concrete_type_adapter
        if not adapter:
            raise Exception("App not initialized yet. Please call init_app first")
        return adapter.create(self._config, *args, **kwargs)


class SecretsMigrator:
    """
    Re-encrypts every ``EncryptedType`` column found in SQLAlchemy's metadata,
    moving stored values from a previous secret key to the current one.
    """

    def __init__(self, previous_secret_key: str) -> None:
        """
        :param previous_secret_key: Key the stored values are currently
            encrypted with (the key being rotated away from)
        """
        from superset import db  # pylint: disable=import-outside-toplevel

        self._db = db
        self._previous_secret_key = previous_secret_key
        # Passed to EncryptedType's process_result_value / process_bind_param
        self._dialect: Dialect = db.engine.url.get_dialect()

    def discover_encrypted_fields(self) -> Dict[str, Dict[str, EncryptedType]]:
        """
        Iterates over SqlAlchemy's metadata, looking for EncryptedType
        columns along the way. Builds up a dict of
        table_name -> dict of col_name: enc type instance
        Tables without any EncryptedType column are not included.
        :return: Mapping of table name to its encrypted columns
        """
        meta_info: Dict[str, Dict[str, EncryptedType]] = {}

        for table_name, table in self._db.metadata.tables.items():
            for col_name, col in table.columns.items():
                if isinstance(col.type, EncryptedType):
                    meta_info.setdefault(table_name, {})[col_name] = col.type

        return meta_info

    @staticmethod
    def _read_bytes(col_name: str, value: Any) -> Optional[bytes]:
        """
        Normalize a raw column value returned by the DB driver into bytes.

        :param col_name: Column name, only used in the error message
        :param value: Raw value as fetched from the DB
        :return: ``None`` for NULL, otherwise the value as bytes
        :raises ValueError: If the value's type is not one handled here
        """
        if value is None or isinstance(value, bytes):
            return value
        # Note that the Postgres Driver returns memoryview's for BLOB types
        if isinstance(value, memoryview):
            return value.tobytes()
        if isinstance(value, str):
            return value.encode("utf8")

        # Just bail if we haven't seen this type before...
        raise ValueError(f"DB column {col_name} has unknown type: {type(value)}")

    @staticmethod
    def _select_columns_from_table(
        conn: Connection, column_names: List[str], table_name: str
    ) -> Row:
        """
        Select ``id`` plus ``column_names`` from every row of ``table_name``.

        NOTE(review): annotated as ``Row`` for backward compatibility, but the
        value returned is the result of ``conn.execute`` (iterable of rows).
        Within this class, table/column names come from SQLAlchemy metadata
        (see ``discover_encrypted_fields``), not from user input.
        """
        # text() is required for textual SQL on SQLAlchemy 2.0 and is
        # accepted by older versions as well
        return conn.execute(
            text(f"SELECT id, {','.join(column_names)} FROM {table_name}")
        )

    def _is_decryptable(
        self, encrypted_type: EncryptedType, raw_value: Optional[bytes]
    ) -> bool:
        """Return whether ``encrypted_type`` (current key) can decrypt ``raw_value``."""
        try:
            encrypted_type.process_result_value(raw_value, self._dialect)
        except Exception:  # pylint: disable=broad-except
            return False
        return True

    def _re_encrypt_row(
        self,
        conn: Connection,
        row: Row,
        table_name: str,
        columns: Dict[str, EncryptedType],
    ) -> None:
        """
        Re encrypts all encrypted columns of a single Row with the current key.

        Each column is decrypted with the previous secret key and re-encrypted
        with the current one. A column that only the current key can decrypt
        is treated as already migrated and skipped; the row's remaining
        columns are still processed. No UPDATE is issued when every column is
        already on the current key.

        :param conn: Connection the UPDATE is executed on
        :param row: Current row to reencrypt; must expose ``id`` and every
            column in ``columns``
        :param table_name: Table the row belongs to
        :param columns: Meta info from columns (col name -> EncryptedType)
        :raises ValueError: If a value can be decrypted with neither the
            previous nor the current secret key
        """
        re_encrypted_columns: Dict[str, Any] = {}

        for column_name, encrypted_type in columns.items():
            raw_value = self._read_bytes(column_name, row[column_name])
            previous_encrypted_type = EncryptedType(
                type_in=encrypted_type.underlying_type, key=self._previous_secret_key
            )
            try:
                # NOTE(review): relies on a wrong key surfacing as ValueError
                # from EncryptedType -- confirm for the configured engine
                unencrypted_value = previous_encrypted_type.process_result_value(
                    raw_value, self._dialect
                )
            except ValueError as exc:
                # Failed to unencrypt with the previous key: it may already be
                # encrypted with the current key
                if not self._is_decryptable(encrypted_type, raw_value):
                    raise ValueError(
                        f"Could not decrypt column [{table_name}.{column_name}] "
                        "with either the previous or the current secret key"
                    ) from exc
                logger.info(
                    "Current secret is able to decrypt value on column [%s.%s],"
                    " nothing to do",
                    table_name,
                    column_name,
                )
                # Skip only this column; returning here would leave the row's
                # other columns encrypted with the previous key
                continue

            re_encrypted_columns[column_name] = encrypted_type.process_bind_param(
                unencrypted_value,
                self._dialect,
            )

        if not re_encrypted_columns:
            # Nothing to rewrite; an empty SET clause would be invalid SQL
            return

        set_cols = ",".join(f"{name} = :{name}" for name in re_encrypted_columns)
        logger.info("Processing table: %s", table_name)
        # Bind parameters go in a dict: keyword-argument parameters to
        # Connection.execute are deprecated in 1.4 and removed in 2.0
        conn.execute(
            text(f"UPDATE {table_name} SET {set_cols} WHERE id = :id"),
            {"id": row["id"], **re_encrypted_columns},
        )

    def run(self) -> None:
        """
        Re-encrypt every encrypted column of every row in all tables found by
        ``discover_encrypted_fields``. All updates run inside a single
        ``engine.begin()`` block.
        """
        encrypted_meta_info = self.discover_encrypted_fields()

        with self._db.engine.begin() as conn:
            logger.info("Collecting info for re encryption")
            for table_name, columns in encrypted_meta_info.items():
                column_names = list(columns.keys())
                rows = self._select_columns_from_table(conn, column_names, table_name)

                for row in rows:
                    self._re_encrypt_row(conn, row, table_name, columns)
        logger.info("All tables processed")
